import os
from argparse import ArgumentParser

import numpy as np


def get_args():
    """Parse the command line and return the resulting namespace.

    Two options are accepted, both mandatory:

    * ``--text``: path to the input text file (stored as ``text``).
    * ``--out-dir``: output directory (stored as ``out_dir``).
    """
    cli = ArgumentParser()
    cli.add_argument("--text", required=True, help="Path to text", type=str)
    cli.add_argument("--out-dir", required=True, type=str)
    namespace = cli.parse_args()
    return namespace


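# Split of the non-control speakers listed below: ~56% train, ~19% val, ~25% test.
# All control speakers are placed in the training split.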

# Speaker IDs assigned to the training split. Entries are grouped into the
# sections named by the inline comments (presumably aphasia severity levels
# plus healthy controls -- TODO confirm); the number after each section name
# is the count of entries in that section.
# NOTE: every ID in the "very severe" section also appears in the "severe"
# section, so this list contains repeated entries.
train_spks = [
    # mild: 113
    "kurland12a",
    "scale05c",
    "kurland05a",
    "scale17a",
    "fridriksson02a",
    "kurland09b",
    "tap04a",
    "adler25a",
    "elman15a",
    "MMA12a",
    "MSU02b",
    "kurland04b",
    "kurland07b",
    "elman01a",
    "kurland02c",
    "kurland03b",
    "tucson16a",
    "wozniak05a",
    "fridriksson09a",
    "williamson17a",
    "adler20a",
    "fridriksson05a",
    "ACWT04a",
    "adler14a",
    "williamson15a",
    "wozniak04a",
    "kurland04a",
    "scale35a",
    "scale21a",
    "MSU02a",
    "star03a",
    "fridriksson07a",
    "whiteside18a",
    "tucson06a",
    "kansas19a",
    "adler01a",
    "kurland08a",
    "kurland14a",
    "scale06b",
    "kurland17b",
    "MMA05a",
    "wozniak01a",
    "UNH03a",
    "kansas15a",
    "williamson09a",
    "tcu10b",
    "ACWT07a",
    "BU01a",
    "adler24a",
    "adler07a",
    "williamson08a",
    "MMA09a",
    "scale12b",
    "kansas11a",
    "tucson06b",
    "kurland10b",
    "kurland02d",
    "tap01a",
    "adler08a",
    "UNH01a",
    "kurland14b",
    "kurland25a",
    "kurland23c",
    "BU05a",
    "whiteside06a",
    "scale06a",
    "fridriksson01a",
    "elman05a",
    "MMA02a",
    "UNH11a",
    "elman04a",
    "MMA03a",
    "elman13a",
    "williamson05a",
    "whiteside09a",
    "kansas21a",
    "UNH01c",
    "scale06c",
    "UNH05a",
    "kempler02a",
    "scale22a",
    "kansas04a",
    "tucson18a",
    "elman10a",
    "whiteside20a",
    "BU04a",
    "kurland12b",
    "tucson10a",
    "williamson14a",
    "whiteside17a",
    "wright202a",
    "tcu01a",
    "whiteside19a",
    "cmu03a",
    "williamson18a",
    "scale34a",
    "kurland08b",
    "kurland17c",
    "thompson02a",
    "scale18d",
    "fridriksson09b",
    "fridriksson04a",
    "tucson20a",
    "kurland26b",
    "williamson10a",
    "wozniak06a",
    "williamson07a",
    "kurland28c",
    "whiteside13a",
    "MSU01a",
    "thompson14a",
    "adler12a",
    "thompson01a",
    # moderate: 83
    "MSU08b",
    "ACWT05a",
    "whiteside02a",
    "UNH08a",
    "kempler04a",
    "fridriksson10b",
    "thompson03a",
    "tcu03a",
    "tucson15b",
    "UNH10a",
    "williamson03a",
    "kansas22a",
    "scale26a",
    "cmu02b",
    "fridriksson06b",
    "BU09a",
    "garrett01a",
    "scale15b",
    "williamson16a",
    "whiteside10a",
    "ACWT02a",
    "scale10a",
    "tap02a",
    "scale23a",
    "adler10a",
    "fridriksson10a",
    "wright201a",
    "adler04a",
    "fridriksson12a",
    "tucson07a",
    "kurland01c",
    "tucson11a",
    "scale11a",
    "wright206a",
    "scale11b",
    "scale15c",
    "williamson19a",
    "kurland29a",
    "adler13a",
    "scale18c",
    "kurland29c",
    "scale13a",
    "scale30b",
    "kurland24b",
    "adler18a",
    "williamson01a",
    "adler02a",
    "kurland02b",
    "kurland20a",
    "thompson05a",
    "wozniak07a",
    "scale36a",
    "adler05a",
    "kansas20a",
    "MSU05a",
    "tucson08a",
    "wozniak03a",
    "tucson22a",
    "williamson06a",
    "MSU07a",
    "scale19a",
    "MMA16a",
    "BU07a",
    "whiteside12a",
    "whiteside08a",
    "scale05b",
    "kurland01d",
    "scale18a",
    "MSU03b",
    "BU02a",
    "scale04a",
    "tap14a",
    "kansas17a",
    "MMA13a",
    "adler15a",
    "tucson09a",
    "kansas09a",
    "williamson24a",
    "MSU07b",
    "scale12a",
    "thompson04a",
    "kansas13a",
    "kurland19b",
    # severe: 34
    "scale25a",
    "kurland15c",
    "BU08a",
    "whiteside03a",
    "kansas05a",
    "kansas01a",
    "scale28a",
    "UNH04a",
    "tcu02b",
    "scale09a",
    "fridriksson03a",
    "MMA14a",
    "kansas08a",
    "adler06a",
    "kansas12a",
    "ACWT11a",
    "tucson14a",
    "scale24a",
    "tap09a",
    "tucson15a",
    "tap13a",
    "kurland22b",
    "kansas02a",
    "kansas06a",
    "whiteside04a",
    "ACWT08a",
    "williamson21a",
    "tucson03a",
    "elman06a",
    "elman08a",
    "scale27a",
    "fridriksson06a",
    "adler19a",
    "adler23a",
    # very severe: 9 (each of these IDs is also listed under "severe" above)
    "kurland15c",
    "kansas01a",
    "scale09a",
    "kansas08a",
    "tap09a",
    "kansas02a",
    "kansas06a",
    "williamson21a",
    "adler19a",
    # control: 268
    "capilouto30a",
    "UMD05",
    "richardson191",
    "richardson202",
    "capilouto52a",
    "wright79a",
    "capilouto58a",
    "wright93a",
    "capilouto57a",
    "capilouto35a",
    "richardson177",
    "wright98a",
    "wright82a",
    "UMD19",
    "wright55a",
    "richardson25",
    "wright33a",
    "capilouto77a",
    "capilouto53a",
    "wright99a",
    "MSUC08b",
    "capilouto39a",
    "capilouto08a",
    "MSUC07a",
    "wright29a",
    "capilouto64a",
    "MSUC04b",
    "wright43a",
    "capilouto44a",
    "capilouto11a",
    "capilouto65a",
    "capilouto42a",
    "UMD21",
    "UMD04",
    "capilouto79a",
    "MSUC02a",
    "richardson178",
    "wright04a",
    "capilouto23a",
    "MSUC03b",
    "capilouto09a",
    "wright21a",
    "MSUC08a",
    "wright38a",
    "capilouto60a",
    "richardson22",
    "richardson175",
    "capilouto02a",
    "wright80a",
    "capilouto28a",
    "richardson20",
    "richardson60",
    "capilouto24a",
    "wright24a",
    "MSUC05a",
    "capilouto19a",
    "capilouto61a",
    "capilouto63a",
    "UMD22",
    "wright85a",
    "capilouto56a",
    "richardson200",
    "wright69a",
    "richardson38",
    "capilouto06a",
    "richardson195",
    "wright14a",
    "wright10a",
    "capilouto54a",
    "wright09a",
    "capilouto62a",
    "richardson165",
    "wright68a",
    "capilouto67a",
    "UMD08",
    "richardson37",
    "richardson171",
    "UMD03",
    "UMD18",
    "UMD17",
    "MSUC02b",
    "richardson204",
    "capilouto14a",
    "capilouto03a",
    "richardson192",
    "wright62a",
    "MSUC07b",
    "wright06a",
    "MSUC03a",
    "capilouto51a",
    "UMD10",
    "richardson41",
    "wright48a",
    "wright100a",
    "wright97a",
    "wright59a",
    "richardson59",
    "UMD11",
    "UMD14",
    "richardson205",
    "richardson168",
    "richardson194",
    "wright30a",
    "wright95a",
    "wright07a",
    "wright12a",
    "wright50a",
    "wright87a",
    "capilouto32a",
    "UMD23",
    "capilouto29a",
    "UMD20",
    "capilouto18a",
    "MSUC09b",
    "richardson189",
    "wright90a",
    "capilouto13a",
    "capilouto46a",
    "richardson203",
    "wright77a",
    "wright46a",
    "wright58a",
    "richardson197",
    "wright20a",
    "wright16a",
    "capilouto05a",
    "wright26a",
    "wright51a",
    "capilouto12a",
    "richardson24",
    "wright40a",
    "wright31a",
    "UMD09",
    "wright91a",
    "richardson173",
    "richardson23",
    "wright11a",
    "wright71a",
    "capilouto21a",
    "capilouto17a",
    "UMD02",
    "richardson36",
    "richardson188",
    "richardson199",
    "capilouto10a",
    "wright63a",
    "wright89a",
    "MSUC06a",
    "wright61a",
    "richardson167",
    "capilouto16a",
    "wright02a",
    "wright88a",
    "wright45a",
    "MSUC04a",
    "UMD24",
    "wright03a",
    "richardson184",
    "wright84a",
    "UMD15",
    "richardson179",
    "wright39a",
    "richardson35",
    "capilouto01a",
    "capilouto34a",
    "capilouto15a",
    "wright81a",
    "wright53a",
    "wright94a",
    "wright32a",
    "wright08a",
    "MSUC06b",
    "richardson166",
    "richardson186",
    "richardson206",
    "wright66a",
    "wright57a",
    "wright60a",
    "richardson176",
    "wright17a",
    "richardson17",
    "capilouto68a",
    "wright65a",
    "wright19a",
    "capilouto38a",
    "MSUC09a",
    "wright96a",
    "capilouto45a",
    "capilouto20a",
    "MSUC01a",
    "kempler01a",
    "wright42a",
    "richardson18",
    "richardson58",
    "richardson172",
    "richardson19",
    "wright01a",
    "wright78a",
    "MSUC01b",
    "richardson39",
    "capilouto55a",
    "richardson185",
    "wright27a",
    "wright05a",
    "UMD01",
    "capilouto37a",
    "wright15a",
    "UMD13",
    "wright34a",
    "wright18a",
    "capilouto27a",
    "wright67a",
    "wright47a",
    "wright35a",
    "wright101a",
    "wright72a",
    "capilouto43a",
    "richardson170",
    "wright49a",
    "capilouto36a",
    "capilouto25a",
    "wright92a",
    "wright73a",
    "wright64a",
    "capilouto33a",
    "capilouto47a",
    "capilouto80a",
    "UMD12",
    "richardson21",
    "wright23a",
    "wright102a",
    "wright74a",
    "richardson92",
    "richardson196",
    "MSUC05b",
    "capilouto31a",
    "richardson42",
    "UMD16",
    "capilouto40a",
    "richardson169",
    "capilouto41a",
    "richardson201",
    "wright28a",
    "wright13a",
    "wright25a",
    "wright86a",
    "wright83a",
    "capilouto50a",
    "capilouto59a",
    "UMD06",
    "capilouto04a",
    "richardson54",
    "capilouto66a",
    "richardson198",
    "wright36a",
    "capilouto48a",
    "capilouto07a",
    "capilouto78a",
    "capilouto22a",
    "wright70a",
    "richardson174",
    "wright52a",
    "capilouto49a",
    "capilouto26a",
    "richardson34",
    "wright37a",
    "wright75a",
    "wright22a",
]

# Speaker IDs assigned to the validation split, grouped into the sections
# named by the inline comments (presumably aphasia severity levels -- TODO
# confirm); the number after each section name is the count of entries in it.
# This list has no "control" section.
# NOTE: every ID in the "very severe" section also appears in the "severe"
# section, so this list contains repeated entries.
val_spks = [
    # mild: 38
    "adler09a",
    "kurland21c",
    "kurland28a",
    "MSU06b",
    "kansas18a",
    "adler17a",
    "kansas07a",
    "MSU01b",
    "kurland03a",
    "scale20a",
    "whiteside01a",
    "scale17c",
    "MMA22a",
    "adler21a",
    "kurland06b",
    "kurland17a",
    "kurland25b",
    "UNH06a",
    "williamson13a",
    "kurland07a",
    "MBA01a",
    "MMA19a",
    "UCL04a",
    "wright203a",
    "tucson08b",
    "scale02b",
    "scale32a",
    "kansas03a",
    "whiteside05a",
    "scale06d",
    "ACWT09a",
    "tcu05a",
    "tap07a",
    "BU12a",
    "thompson06a",
    "scale16a",
    "tucson04a",
    "kurland21b",
    # moderate: 28
    "whiteside14a",
    "tap10a",
    "MSU08a",
    "MMA15a",
    "tap15a",
    "scale30a",
    "UNH09a",
    "kurland24c",
    "kurland24a",
    "kurland27b",
    "UCL02a",
    "elman12a",
    "kurland19a",
    "wright207a",
    "tap17a",
    "tcu08a",
    "williamson04a",
    "tap11a",
    "ACWT10a",
    "BU10a",
    "elman07a",
    "kurland29b",
    "UCL01a",
    "wright205a",
    "kempler03a",
    "ACWT03a",
    "scale33a",
    "kurland27a",
    # severe: 12
    "kurland18b",
    "UCL03a",
    "MMA10a",
    "kurland22a",
    "kurland16b",
    "kurland15a",
    "williamson23a",
    "scale07a",
    "UNH02b",
    "MMA08a",
    "kansas16a",
    "kurland18a",
    # very severe: 3 (each of these IDs is also listed under "severe" above)
    "MMA10a",
    "kurland15a",
    "scale07a",
]

# Speaker IDs assigned to the test split, grouped into the sections named by
# the inline comments (presumably aphasia severity levels -- TODO confirm);
# the number after each section name is the count of entries in it.
# This list has no "control" section.
# NOTE: every ID in the "very severe" section also appears in the "severe"
# section, so this list contains repeated entries.
test_spks = [
    # mild: 51
    "ACWT12a",
    "MSU06a",
    "adler22a",
    "elman11b",
    "UNH01b",
    "tucson01a",
    "elman01b",
    "kurland26a",
    "williamson02a",
    "tcu09a",
    "MSU04b",
    "thompson07c",
    "scale12c",
    "BU06a",
    "whiteside07a",
    "kurland10a",
    "tap12a",
    "kurland26c",
    "wright204a",
    "fridriksson11a",
    "kurland05b",
    "MSU04a",
    "tap05a",
    "kurland02e",
    "MMA11a",
    "thompson11a",
    "kurland23a",
    "williamson14b",
    "thompson07a",
    "thompson13a",
    "adler03a",
    "kurland06a",
    "thompson10a",
    "kurland28b",
    "williamson09b",
    "kurland21a",
    "tap18a",
    "MMA17a",
    "BU03a",
    "thompson09a",
    "scale15d",
    "thompson07b",
    "wozniak02a",
    "kurland09a",
    "kurland23b",
    "scale08a",
    "thompson12a",
    "kurland25c",
    "garrett02a",
    "thompson08a",
    "scale12d",
    # moderate: 37
    "elman02a",
    "fridriksson13a",
    "MSU03a",
    "kurland13b",
    "ACWT01a",
    "tap19a",
    "scale15a",
    "scale18b",
    "elman11a",
    "scale01a",
    "kansas14a",
    "scale31a",
    "scale02a",
    "scale05a",
    "williamson11a",
    "kurland13a",
    "elman03a",
    "whiteside15a",
    "williamson12a",
    "scale14c",
    "elman14a",
    "kurland02a",
    "adler16a",
    "tap16a",
    "kurland19c",
    "whiteside16a",
    "tucson13a",
    "williamson12b",
    "elman09a",
    "kurland27c",
    "kansas10a",
    "tap08a",
    "scale14a",
    "whiteside11a",
    "MSU05b",
    "kansas23a",
    "scale38a",
    # severe: 15
    "tucson19a",
    "kurland16a",
    "adler11a",
    "tcu07a",
    "UNH07a",
    "tap06a",
    "fridriksson08b",
    "kurland15b",
    "tucson12a",
    "scale03a",
    "MMA06a",
    "tap03a",
    "fridriksson03b",
    "tucson02a",
    "UNH02a",
    # very severe: 4 (each of these IDs is also listed under "severe" above)
    "adler11a",
    "fridriksson08b",
    "kurland15b",
    "UNH02a",
]

# Speaker lists in the order train, val, test. main() pairs this list
# positionally with its split names ("train", "val", "test"), so the order
# here must match.
spk_splits = [train_spks, val_spks, test_spks]


def main():
    """Split a text file into train/val/test subsets by speaker.

    Reads the file given by ``--text``, where each line is
    ``<utt-id> <transcript>`` and each utterance ID has the form
    ``<speaker>-<suffix>``. For every split it then writes three files under
    ``<out-dir>/<split>/``:

    * ``utt.list``: one utterance ID per line
    * ``text``: ``<utt-id>\\t<transcript>``
    * ``utt2spk``: ``<utt-id>\\t<speaker>``

    Speakers in the predefined lists that do not occur in the input file are
    reported on stdout and skipped. Speakers that are listed more than once
    within a split are processed only once, so no utterance is written twice.

    Raises:
        AssertionError: if a speaker is assigned to more than one split.
        ValueError: if an input line lacks a transcript, or an utterance ID
            does not contain exactly one ``-``.
    """
    args = get_args()

    # Group utterance IDs by speaker and remember every transcript.
    spk2utts = {}
    utt2trans = {}
    with open(args.text, encoding="utf-8") as f:
        for line in f:
            utt, trans = line.rstrip("\n").split(maxsplit=1)
            spk, _ = utt.split("-")
            spk2utts.setdefault(spk, []).append(utt)
            utt2trans[utt] = trans

    splits = ["train", "val", "test"]
    out_dir = args.out_dir

    # The speaker lists repeat some IDs (the "very severe" entries are listed
    # under "severe" as well). Drop the repeats, keeping first-seen order, so
    # the statistics count each speaker once and no utterance is emitted twice.
    unique_spk_splits = [list(dict.fromkeys(spks)) for spks in spk_splits]

    # print fraction of (unique) speakers in each split
    n_spks = np.asarray([len(spks) for spks in unique_spk_splits], dtype=float)
    n_spks /= np.sum(n_spks)
    print(f"Percentage of train, val and test speakers: {n_spks}")

    # make sure there is no overlap between splits
    train_set, val_set, test_set = (set(spks) for spks in unique_spk_splits)
    assert not (train_set & test_set), train_set & test_set
    assert not (train_set & val_set), train_set & val_set
    assert not (test_set & val_set), test_set & val_set

    for split, spks in zip(splits, unique_spk_splits):
        subset_dir = os.path.join(out_dir, split)
        os.makedirs(subset_dir, exist_ok=True)

        def open_out(name, subset_dir=subset_dir):
            return open(os.path.join(subset_dir, name), "w", encoding="utf-8")

        # Context managers guarantee the files are closed even if a write fails.
        with open_out("utt.list") as utt_list, open_out("text") as text:
            with open_out("utt2spk") as utt2spk:
                for spk in spks:
                    if spk not in spk2utts:
                        print(
                            f"Skipping utterances of {spk} "
                            f"since they are not found in {args.text}"
                        )
                        continue

                    for utt in spk2utts[spk]:
                        utt_list.write(f"{utt}\n")
                        text.write(f"{utt}\t{utt2trans[utt]}\n")
                        utt2spk.write(f"{utt}\t{spk}\n")


# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
